{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "前面的博客中我们说过，在加载数据和预处理数据时使用tf.data.Dataset对象将极大将我们从建模前的数据清理工作中释放出来，那么，怎么将自定义的数据集加载为DataSet对象呢？这对很多新手来说都是一个难题，因为绝大多数案例教学都是以mnist数据集作为例子讲述如何将数据加载到Dataset中，而英文资料对这方面的介绍隐藏得有点深。本文就来捋一捋如何加载自定义的图片数据集实现图片分类，后续将继续介绍如何加载自定义的text、mongodb等数据。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果你已有数据集，那么，请将所有数据存放在同一目录下，然后将不同类别的图片分门别类地存放在不同的子目录下,目录树如下所示：\n",
    "\n",
    "$ tree flower_photos -L 1\n",
    "\n",
    "flower_photos  \n",
    "├── daisy  \n",
    "├── dandelion  \n",
    "├── LICENSE.txt  \n",
    "├── roses  \n",
    "├── sunflowers  \n",
    "└── tulips  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "所有的数据都存放在flower_photos目录下，每一个子目录（daisy、dandelion等等）存放的都是一个类别的图片。如果你已有自己的数据集，那就按上面的结构来存放，如果没有，想操作学习一下，你可以通过下面代码下载上述图片数据集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import pathlib\n",
    "data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n",
    "                                         fname='flower_photos', untar=True)\n",
    "data_root = pathlib.Path(data_root_orig)\n",
    "print(data_root)  # 打印出数据集所在目录"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下载好后，建议将整个flower_photos目录移动到项目根目录下。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3670"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import random\n",
    "import pathlib\n",
    "data_path = pathlib.Path('./data/flower_photos')\n",
    "all_image_paths = list(data_path.glob('*/*'))  \n",
    "all_image_paths = [str(path) for path in all_image_paths]  # 所有图片路径的列表\n",
    "random.shuffle(all_image_paths)  # 打散\n",
    "\n",
    "image_count = len(all_image_paths)\n",
    "image_count"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "查看一下前5张："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['data/flower_photos/sunflowers/9448615838_04078d09bf_n.jpg',\n",
       " 'data/flower_photos/roses/15222804561_0fde5eb4ae_n.jpg',\n",
       " 'data/flower_photos/daisy/18622672908_eab6dc9140_n.jpg',\n",
       " 'data/flower_photos/roses/459042023_6273adc312_n.jpg',\n",
       " 'data/flower_photos/roses/16149016979_23ef42b642_m.jpg']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_image_paths[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "读取图片的同时，我们也不能忘记图片与标签的对应，要创建一个对应的列表来存放图片标签，不过，这里所说的标签不是daisy、dandelion这些具体分类名，而是整型的索引，毕竟在建模的时候y值一般都是整型数据，所以要创建一个字典来建立分类名与标签的对应关系："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "label_names = sorted(item.name for item in data_path.glob('*/') if item.is_dir())\n",
    "label_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "label_to_index = dict((name, index) for index, name in enumerate(label_names))\n",
    "label_to_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/flower_photos/sunflowers/9448615838_04078d09bf_n.jpg  --->   3\n",
      "data/flower_photos/roses/15222804561_0fde5eb4ae_n.jpg  --->   2\n",
      "data/flower_photos/daisy/18622672908_eab6dc9140_n.jpg  --->   0\n",
      "data/flower_photos/roses/459042023_6273adc312_n.jpg  --->   2\n",
      "data/flower_photos/roses/16149016979_23ef42b642_m.jpg  --->   2\n"
     ]
    }
   ],
   "source": [
    "for image, label in zip(all_image_paths[:5], all_image_labels[:5]):\n",
    "    print(image, ' --->  ', label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "好了，现在我们可以创建一个Dataset了："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "不过，这个ds可不是我们想要的，毕竟，里面的元素只是图片路径，所以我们要进一步处理。这个处理包含读取图片、重新设置图片大小、归一化、转换类型等操作，我们将这些操作统统定义到一个方法里："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_and_preprocess_from_path_label(path, label):\n",
    "    image = tf.io.read_file(path)  # 读取图片\n",
    "    image = tf.image.decode_jpeg(image, channels=3)\n",
    "    image = tf.image.resize(image, [192, 192])  # 原始图片大小为(266, 320, 3)，重设为(192, 192)\n",
    "    image /= 255.0  # 归一化到[0,1]范围\n",
    "    return image, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_label_ds  = ds.map(load_and_preprocess_from_path_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<MapDataset shapes: ((192, 192, 3), ()), types: (tf.float32, tf.int32)>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "image_label_ds "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这时候，其实就已经将自定义的图片数据集加载到了Dataset对象中，不过，我们还能秀，可以继续shuffle随机打散、分割成batch、数据repeat操作。这些操作有几点需要注意：  \n",
    "（1）先shuffle、repeat、batch三种操作顺序有讲究：  \n",
    "- 在repeat之后shuffle，会在epoch之间数据随机（当有些数据出现两次的时候，其他数据还没有出现过）\n",
    "- 在batch之后shuffle，会打乱batch的顺序，但是不会在batch之间打乱数据。 \n",
    "\n",
    "（2）shuffle操作时，buffer_size越大，打乱效果越好，但消耗内存越大，可能造成延迟。  \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "推荐通过使用 tf.data.Dataset.apply 方法和融合过的 tf.data.experimental.shuffle_and_repeat 函数来执行这些操作:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = image_label_ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))\n",
    "BATCH_SIZE = 32\n",
    "ds = ds.batch(BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "好了，至此，本文内容其实就结束了，因为已经将自定义的图片数据集加载到了Dataset中。\n",
    "\n",
    "下面的内容作为扩展阅读。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 扩展"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上面的方法是简单的在每次epoch迭代中单独读取每个文件，在本地使用 CPU 训练时这个方法是可行的，但是可能不足以进行GPU训练并且完全不适合任何形式的分布式训练。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以使用tf.data.Dataset.cache在epoch迭代过程间缓存计算结果。这能极大提升程序效率，特别是当内存能容纳全部数据时。\n",
    "\n",
    "在被预处理之后（解码和调整大小），图片就被缓存了："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = image_label_ds.cache()  # 缓存\n",
    "ds = ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用内存缓存的一个缺点是必须在每次运行时重建缓存，这使得每次启动数据集时有相同的启动延迟。如果内存不够容纳数据，使用一个缓存文件："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = image_label_ds.cache(filename='./cache.tf-data')\n",
    "ds = ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 参考"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "https://tensorflow.google.cn/tutorials/load_data/images"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow_gpu",
   "language": "python",
   "name": "tensorflow_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
